import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Matplotlib colour names; not referenced by the functions in this part of the file.
color_list = [
    "tab:green",
    "tab:orange",
    "tab:blue",
]
# Number of expert trials that env() expects to find in the trajectory file.
num_trials = 50

def _visitation_heat(trajectory_file, steps_per_trial=35):
    """Return a (13, 10) grid of scaled visit counts from the expert trajectories.

    The file is read with ``np.loadtxt`` and reshaped to
    ``(steps_per_trial * num_trials, 6)``. Each row contributes two cells,
    (x, y) in columns 0-1 and (x, y) in columns 2-3; columns 4-5 are ignored.
    The grid is indexed ``[y, x]``.
    """
    data = np.loadtxt(trajectory_file, dtype=float)
    trajectories = data.reshape(steps_per_trial * num_trials, 6)
    heat = np.zeros((13, 10))
    for x_col, y_col in ((0, 1), (2, 3)):
        # astype(int) truncates toward zero exactly like int(); np.add.at
        # accumulates repeated (row, col) pairs, unlike heat[idx] += 1.
        rows = trajectories[:, y_col].astype(int)
        cols = trajectories[:, x_col].astype(int)
        np.add.at(heat, (rows, cols), 1)
    # NOTE(review): 100 is presumably 2 * num_trials (two cells per row) -- confirm.
    heat /= 100.0
    # Cells forced to 1.0 regardless of their counts. Their meaning is not
    # visible here (possibly colour-scale anchors for a heat-map overlay) -- TODO confirm.
    for row, col in ((0, 3), (0, 9), (1, 8), (1, 9),
                     (11, 8), (12, 9), (12, 3), (12, 4)):
        heat[row, col] = 1.0
    return heat


def _draw_gridworld(ax):
    """Draw the thick boundary segments, obstacle markers and cell labels onto *ax*."""
    # Thick black segments (presumably walls), each as ((x0, x1), (y0, y1)).
    walls = (
        ((0, 10), (0, 0)), ((3, 10), (2, 2)), ((0, 1), (1, 1)),
        ((0, 0), (0, 1)), ((0, 0), (6, 7)), ((1, 1), (1, 6)),
        ((1, 1), (7, 12)), ((0, 1), (12, 12)), ((3, 10), (11, 11)),
        ((3, 3), (2, 11)), ((0, 1), (6, 6)), ((0, 1), (7, 7)),
        ((0, 0), (12, 13)), ((0, 10), (13, 13)), ((10, 10), (0, 2)),
        ((10, 10), (11, 13)), ((8, 8), (0, 1)), ((9, 9), (0, 1)),
        ((8, 9), (1, 1)), ((3, 3), (11, 12)), ((4, 4), (11, 12)),
        ((3, 4), (12, 12)), ((2, 2), (1, 2)), ((3, 3), (1, 2)),
        ((2, 3), (1, 1)), ((2, 3), (2, 2)), ((7, 7), (12, 13)),
        ((8, 8), (12, 13)), ((7, 8), (12, 12)),
    )
    for xs, ys in walls:
        ax.plot(xs, ys, color='black', linewidth=2)

    # Obstacle cells given by their lower-left corner; a red X marks each centre.
    obstacles = [(x, y) for x in range(3, 10) for y in range(2, 11)]
    obstacles += [(8, 0), (2, 1), (3, 11), (7, 12)]
    obstacles += [(0, y) for y in (*range(1, 6), *range(7, 12))]
    centres = np.array(obstacles, dtype=float) + 0.5
    ax.scatter(centres[:, 0], centres[:, 1], s=160, c="r", marker="x")

    for x, y, label in ((0.25, 0.35, '$L2$'), (0.25, 12.35, '$L1$'),
                        (9.25, 0.35, '$E2$'), (9.25, 12.35, '$E1$'),
                        (0.35, 6.35, '$G$')):
        ax.text(x, y, label, color='black', fontsize=10)


def env(trajectory_file="expert_trajectory_file.txt"):
    """Draw the 10x13 grid-world map, show it, and return the visitation grid.

    Parameters
    ----------
    trajectory_file : str or path-like, optional
        Whitespace-separated text file of expert trajectories. It must contain
        exactly ``35 * num_trials * 6`` numbers, or the reshape fails.

    Returns
    -------
    numpy.ndarray
        (13, 10) array of scaled visit counts indexed ``[y, x]``. It can be
        overlaid on the map, e.g.
        ``ax.imshow(state_heat, cmap='viridis', extent=[0, 10, 13, 0])``.
    """
    state_heat = _visitation_heat(trajectory_file)

    _, ax = plt.subplots()
    ax.axis('scaled')
    # One tick per cell boundary so that the grid lines outline every cell.
    ax.set_xticks(np.arange(10))
    ax.set_yticks(np.arange(13))
    ax.axis([0, 10, 0, 13])
    ax.grid(linestyle='-', color='black')

    _draw_gridworld(ax)

    # Hide tick marks and labels but keep the grid lines. Assigning the
    # per-tick attributes (tick1On, label1On, ...) no longer has any effect:
    # they were deprecated in Matplotlib 3.1 and later removed.
    ax.tick_params(axis='both', which='both',
                   bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labeltop=False,
                   labelleft=False, labelright=False)
    plt.show()
    return state_heat

# Draw and show the map as soon as this file is executed.
# NOTE(review): this also runs on import (and before `dynamics` below is
# defined). Consider guarding it with `if __name__ == "__main__":` if the
# module is meant to be importable.
env()

def dynamics(state, action, *, n_cols=10, n_rows=13):
    """Apply one grid move and return the next state as a 2x1 matrix.

    Parameters
    ----------
    state : array-like
        Holds x at index 0 and y at index 1. Typically a 2x1 ``np.matrix``,
        but a 2-element list or 1-D array also works.
    action : scalar or single-element array-like
        0 -> y + 1, 1 -> y - 1, 2 -> x - 1, 3 -> x + 1. Any other value
        leaves the state unchanged.
    n_cols, n_rows : int, keyword-only
        Grid size. A move that would leave ``0 <= x < n_cols`` or
        ``0 <= y < n_rows`` is ignored. The defaults (10 x 13) match the map
        drawn by ``env()``.

    Returns
    -------
    numpy.matrix
        Shape (2, 1): ``[[x], [y]]`` after the move.
    """
    coords = np.asarray(state)
    x = coords[0].item()
    y = coords[1].item()
    a = np.asarray(action).item()

    if a == 0 and y < n_rows - 1:
        y += 1
    elif a == 1 and y > 0:
        y -= 1
    elif a == 2 and x > 0:
        x -= 1
    elif a == 3 and x < n_cols - 1:
        x += 1
    # np.mat was removed in NumPy 2.0; np.asmatrix works on both 1.x and 2.x.
    return np.asmatrix([x, y]).T


















